r"""
Run one replication (a single random seed) of a Bayesian optimization benchmark.
"""
import gc
import time
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, Optional

import numpy as np
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions.warnings import BadInitialCandidatesWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.optim.optimize import optimize_acqf
from botorch.optim.utils import fix_features
from botorch.test_functions.base import BaseTestProblem, MultiObjectiveTestProblem
from botorch.utils.sampling import draw_sobol_samples

from .experiment_utils import (
    eval_problem,
    generate_initial_data,
    get_acqf,
    get_problem,
    initialize_model,
)
from .performance_evaluation import (
    update_all_performance_summary,
)
from gpytorch import settings
from torch import Tensor

# Show BadInitialCandidatesWarning every time it is raised, not only on its
# first occurrence, so that each poor acquisition initialization is visible.
warnings.simplefilter("always", BadInitialCandidatesWarning)

# Labels accepted by `run_one_replication`. "sobol" draws quasi-random Sobol
# points; every other label is forwarded to `get_acqf` to build an
# acquisition function.
supported_labels = (
    "sobol ei qei ehvi qehvi qnehvi logei qlogei pi logpi qpi qlogpi logehvi "
    "qnehvis qlogehvi qlognehvi qlognehvis qlognehvis_pm kg gibbon jes cei logcei"
).split()

# Performance-summary keys whose per-iteration entries are combined with
# `torch.stack` when results are saved (provided the first entry is not None).
STACKABLE_RESULTS = ("in_sample_hv", "best_obj", "best_inferred_obj")


def _standardize_and_fit(
    X: Tensor,
    Y: Tensor,
    noise_var: Optional[Tensor],
    num_objectives: int,
    model_kwargs: Dict[str, Any],
):
    r"""Standardize the objective outputs and fit a fresh model to them.

    NOTE: Only the first `num_objectives` output columns are standardized;
    constraint dimensions will be standardized separately to ensure that
    c(x) <= 0 for all feasible points.

    Args:
        X: Training inputs (normalized to the unit cube).
        Y: Raw observed training outputs.
        noise_var: Observation noise variance, or None.
        num_objectives: Number of objective columns in `Y` (differs from
            `Y.shape[-1]` if there are constraint outputs).
        model_kwargs: Extra keyword arguments for `initialize_model`.

    Returns:
        A tuple `(stdized_Y, standardize_tf, mll, model)`. `standardize_tf`
        has been put in eval mode, and `mll` has been fit with
        `fit_gpytorch_mll`.
    """
    standardize_tf = Standardize(m=Y.shape[-1], outputs=list(range(num_objectives)))
    stdized_Y, noise_var_stdized = standardize_tf(Y, noise_var)
    standardize_tf.eval()
    mll, model = initialize_model(
        train_X=X,
        train_Y=stdized_Y,
        num_objectives=num_objectives,
        noise_var=noise_var_stdized,
        **model_kwargs,
    )
    fit_gpytorch_mll(mll)
    return stdized_Y, standardize_tf, mll, model


def run_one_replication(
    seed: int,
    label: str,
    function_name: str,
    batch_size: int,
    evaluation_budget: Optional[float] = None,
    n_initial_points: Optional[int] = None,
    optimization_kwargs: Optional[dict] = None,
    acqf_kwargs: Optional[dict] = None,
    model_kwargs: Optional[dict] = None,
    dtype: torch.dtype = torch.double,
    device: Optional[torch.device] = None,
    save_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    problem_kwargs: Optional[Dict[str, Any]] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    r"""Run the BO loop until the evaluation budget is spent.

    Args:
        seed: The experiment seed (used for both torch and numpy RNGs).
        label: The label / algorithm to use; must be in `supported_labels`.
        function_name: The name of the test function to use.
        batch_size: The q-batch size, i.e., number of parallel function evaluations.
        evaluation_budget: Total number of function evaluations to spend in the
            BO loop, excluding the initial design. Each iteration costs
            `batch_size`. The loop runs while the cumulative cost is below the
            budget, so the final batch may overshoot it. Required despite the
            `None` default.
        n_initial_points: Number of initial evaluations to use. If None,
            2*(dim+1) points are used by default.
        optimization_kwargs: Arguments passed to `optimize_acqf`. Defaults are
            `num_restarts=20`, `raw_samples=1024`, and `options` containing
            `batch_limit=5` and `init_batch_limit=32`. The optional boolean key
            `random_initialization` selects uniformly random initial
            conditions. The caller's dict is not modified.
        acqf_kwargs: Keyword arguments for the acquisition function. Its
            `posterior_transform` and `objective` entries are also used when
            computing the performance summary.
        model_kwargs: Keyword arguments for `initialize_model`.
        dtype: The tensor dtype to use.
        device: The device to use. Defaults to CUDA if available, else CPU.
        save_callback: Optional callable invoked with the final output dict,
            e.g. to save results to file.
        problem_kwargs: Keyword arguments for the test problem. If it contains
            `noise_se`, that value is also used as the observation noise
            standard deviation.
        verbose: A boolean indicating whether to use verbose printing.

    Returns:
        The output dict with keys `label`, `X`, `Y`, `cost`, `wall_time`,
        `gen_wall_time`, plus the performance-summary entries. Entries listed
        in `STACKABLE_RESULTS` are stacked into tensors.

    Raises:
        ValueError: If `label` is not supported or `evaluation_budget` is None.
    """
    # Validate up front, before any expensive problem setup or model fitting.
    if label not in supported_labels:
        raise ValueError(f"Label not supported: {label!r}.")
    if evaluation_budget is None:
        raise ValueError("`evaluation_budget` must be specified.")

    # set up
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tkwargs = {"dtype": dtype, "device": device}
    model_kwargs = model_kwargs or {}
    acqf_kwargs = acqf_kwargs or {}
    problem_kwargs = problem_kwargs or {}

    # get noise level
    noise_se = problem_kwargs.get("noise_se")
    if noise_se is not None:
        noise_se = torch.tensor(noise_se, **tkwargs)
        noise_var = noise_se.pow(2)
    else:
        noise_var = None

    # Copy the optimization kwargs (and the nested options dict) so that filling
    # in defaults never mutates dicts owned by the caller, who may reuse them
    # across replications.
    optimization_kwargs = dict(optimization_kwargs or {})
    options = dict(optimization_kwargs.get("options") or {})
    optimization_kwargs["options"] = options
    optimization_kwargs.setdefault("num_restarts", 20)
    optimization_kwargs.setdefault("raw_samples", 1024)
    options.setdefault("batch_limit", 5)
    options.setdefault("init_batch_limit", 32)

    # get test problem
    base_function = get_problem(name=function_name, **problem_kwargs)
    base_function.to(**tkwargs)
    is_moo = isinstance(base_function, MultiObjectiveTestProblem)
    num_objectives = base_function.num_objectives

    # We work within the unit cube. Evaluating the objectives unnormalizes
    # the data to the raw bounds when necessary.
    standard_bounds = torch.ones(2, base_function.dim, **tkwargs)
    standard_bounds[0] = 0

    # Get the initial data.
    if n_initial_points is None:
        n_initial_points = 2 * (base_function.dim + 1)
    X, Y_true, Y = generate_initial_data(
        n=n_initial_points,
        base_function=base_function,
        bounds=standard_bounds,
        tkwargs=tkwargs,
        noise_se=noise_se,
    )
    if verbose:
        print(f"{Y = }")

    wall_time = []
    gen_wall_time = []
    costs = []
    total_cost = 0
    start_time = time.monotonic()

    # Fit the initial model.
    stdized_Y, standardize_tf, mll, model = _standardize_and_fit(
        X=X,
        Y=Y,
        noise_var=noise_var,
        num_objectives=num_objectives,
        model_kwargs=model_kwargs,
    )

    # Record the performance summary, including the model-inferred best
    # (important for comparing with KG / entropy-search methods).
    posterior_transform = acqf_kwargs.get("posterior_transform", None)
    mc_objective = acqf_kwargs.get("objective", None)
    all_performance_summary = defaultdict(list)
    update_all_performance_summary(
        Y_true=Y_true,
        base_function=base_function,
        all_performance_summary=all_performance_summary,
        model=model,
        posterior_transform=posterior_transform,
        mc_objective=mc_objective,
        bounds=standard_bounds,
    )

    # BO loop until the evaluation budget is spent.
    while total_cost < evaluation_budget:
        if verbose:
            print(dict(model.named_parameters()))
        last_gen_wall_time = gen_wall_time[-1] if gen_wall_time else None
        if is_moo:
            progress = (
                "current in-sample HV: "
                f"{all_performance_summary['in_sample_hv'][-1].item()}."
            )
        else:
            progress = (
                f"current best obj: {all_performance_summary['best_obj'][-1].item()}."
            )
        print("################################")
        print("Performance Summary")
        print("################################")
        print(
            f"Starting label {label}, seed {seed}, cost {round(total_cost, 3)}, "
            f"time: {time.monotonic()-start_time}, gen_wall_time: {last_gen_wall_time},"
            f" {progress}"
        )

        gen_start_time = time.monotonic()
        candidates = generate_candidate(
            label=label,
            standard_bounds=standard_bounds,
            batch_size=batch_size,
            stdized_Y=stdized_Y,
            model=model,
            X=X,
            tkwargs=tkwargs,
            base_function=base_function,
            acqf_kwargs=acqf_kwargs,
            optimization_kwargs=optimization_kwargs,
            verbose=verbose,
            standardize_tf=standardize_tf,
        )
        # Free the old model before evaluating and refitting.
        del mll, model
        gc.collect()
        gen_wall_time.append(time.monotonic() - gen_start_time)

        # evaluate candidates
        new_y_true, new_y = eval_problem(
            candidates,
            base_function=base_function,
            noise_se=noise_se,
        )
        X = torch.cat([X, candidates], dim=0)
        Y_true = torch.cat([Y_true, new_y_true], dim=0)
        Y = torch.cat([Y, new_y], dim=0)
        total_cost += batch_size
        costs.append(total_cost)
        wall_time.append(time.monotonic() - start_time)

        # Refit on all data, then record performance for this iteration.
        stdized_Y, standardize_tf, mll, model = _standardize_and_fit(
            X=X,
            Y=Y,
            noise_var=noise_var,
            num_objectives=num_objectives,
            model_kwargs=model_kwargs,
        )
        update_all_performance_summary(
            Y_true=Y_true,
            base_function=base_function,
            all_performance_summary=all_performance_summary,
            model=model,
            posterior_transform=posterior_transform,
            mc_objective=mc_objective,
            bounds=standard_bounds,
        )

    # Assemble the final output.
    output_dict = {
        "label": label,
        "X": X.cpu(),
        "Y": Y.cpu(),
        "cost": torch.tensor(costs, dtype=dtype),
        "wall_time": torch.tensor(wall_time, dtype=dtype),
        "gen_wall_time": torch.tensor(gen_wall_time, dtype=dtype),
    }
    output_dict.update(
        {
            k: torch.stack(v, dim=0)
            if k in STACKABLE_RESULTS and v[0] is not None
            else v
            for k, v in all_performance_summary.items()
        }
    )
    if save_callback is not None:
        save_callback(output_dict)
    return output_dict


def generate_candidate(
    label: str,
    standard_bounds: Tensor,
    batch_size: int,
    stdized_Y: Tensor,
    model: Model,
    X: Tensor,
    tkwargs: Dict[str, Any],
    base_function: BaseTestProblem,
    acqf_kwargs: Dict[str, Any],
    optimization_kwargs: Dict[str, Any],
    verbose: bool,
    standardize_tf: Standardize,
) -> Tensor:
    """Generate candidates using the specified acquisition function.

    Args:
        label: The acquisition function name. "sobol" returns quasi-random
            Sobol points without building an acquisition function.
        standard_bounds: The bounds for optimization (the unit cube).
        batch_size: The q-batch size.
        stdized_Y: The standardized observed values.
        model: The fitted model.
        X: The previously evaluated designs (normalized to the unit cube).
        tkwargs: A dictionary containing the dtype and device.
        base_function: The test problem.
        acqf_kwargs: Keyword arguments for the acquisition function.
        optimization_kwargs: Keyword arguments for `optimize_acqf`. The
            optional key `random_initialization` is consumed here and is not
            forwarded. If truthy, uniformly random initial conditions are
            used. The caller's dict is not modified.
        verbose: A boolean indicating whether to use verbose printing.
        standardize_tf: The standardize transform used for transforming
            the reference point.

    Returns:
        A `q x d`-dim tensor of candidates.

    Raises:
        ValueError: If `random_initialization` is requested together with KG.
    """
    if label == "sobol":
        return (
            draw_sobol_samples(
                bounds=standard_bounds,
                n=1,
                q=batch_size,
            )
            .squeeze(0)
            .to(**tkwargs)
        )

    # Work on a copy so that stripping `random_initialization` does not leak
    # into the caller's dict, which is reused across BO iterations.
    optimization_kwargs = dict(optimization_kwargs)
    random_initialization = optimization_kwargs.pop("random_initialization", False)
    # Fail before the (possibly expensive) acquisition function is built.
    if random_initialization and label == "kg":
        raise ValueError("Random initialization not compatible with KG!")
    ic_generator = random_initial_conditions if random_initialization else None

    with settings.cholesky_max_tries(9):
        # Construct the acqf.
        acq_func = get_acqf(
            label=label,
            model=model,
            X_baseline=X,
            base_function=base_function,
            train_Y=stdized_Y,
            standardize_tf=standardize_tf,
            **acqf_kwargs,
        )

        # Optimize the acqf.
        torch.cuda.empty_cache()
        candidates, vals = optimize_acqf(
            acq_function=acq_func,
            bounds=standard_bounds,
            q=batch_size,
            ic_generator=ic_generator,
            **optimization_kwargs,
        )
        if verbose:
            print(f"candidates: {candidates}")
            print(f"vals: {vals}")
        torch.cuda.empty_cache()
        # free memory
        del acq_func
        gc.collect()

    return candidates


def random_initial_conditions(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: Optional[Dict[int, float]] = None,
    **ignore: Any,
) -> Tensor:
    r"""Draw uniformly random starting points for acquisition optimization.

    Intended as an `ic_generator` for `optimize_acqf`. The acquisition
    function is never evaluated here. `acq_function`, `raw_samples` and any
    extra keyword arguments are accepted only so that the signature fits the
    generator calling convention.

    Args:
        acq_function: Unused.
        bounds: A `2 x d` tensor of lower (row 0) and upper (row 1) bounds.
        q: Number of points per restart.
        num_restarts: Number of restarts, i.e. independent starting batches.
        raw_samples: Unused.
        fixed_features: Optional mapping from feature index to a fixed value,
            applied via `fix_features`.

    Returns:
        A `num_restarts x q x d` tensor of starting points within `bounds`.
    """
    lower, upper = bounds[0], bounds[1]
    unit_samples = torch.rand(
        num_restarts, q, bounds.shape[-1], dtype=bounds.dtype, device=bounds.device
    )
    # Affine map from the unit cube onto [lower, upper].
    scaled = unit_samples * (upper - lower) + lower
    return fix_features(scaled, fixed_features=fixed_features)
